Skip to content

[Perf] Support Flashinfer trtllm tinygemm_bf16 router gemm for GPT-OSS#37244

Open
elvischenv wants to merge 3 commits intovllm-project:mainfrom
elvischenv:elvischenv/support-flashinfer-tinygemm
Open

[Perf] Support Flashinfer trtllm tinygemm_bf16 router gemm for GPT-OSS#37244
elvischenv wants to merge 3 commits intovllm-project:mainfrom
elvischenv:elvischenv/support-flashinfer-tinygemm

Conversation

@elvischenv
Copy link
Copy Markdown
Contributor

@elvischenv elvischenv commented Mar 17, 2026

Purpose

Support Flashinfer trtllm tinygemm_bf16 router gemm for GPT-OSS.

Test Plan && Test Result

nsys

PR:

void flashinfer::trtllm_allreduce_fusion::allreduce_fusion_kernel_oneshot_lamport   5.088 μs
void tinygemm_kernel                                                                3.136 μs
void tensorrt_llm::kernels::quantize_with_block_size                                2.848 μs

main:

void flashinfer::trtllm_allreduce_fusion::allreduce_fusion_kernel_oneshot_lamport   5.344 μs
nvjet_sm100_tst_32x64_64x16_4x1_v_bz_splitK_bias_TNN                                3.904 μs
void cublasLt::splitKreduce_kernel                                                  2.816 μs
void tensorrt_llm::kernels::quantize_with_block_size                                3.360 μs

Kernel perf

GPU: NVIDIA B200

gpt-oss-120b: hidden_size=2880, num_experts=128, bias=True
 batch  F.linear(us)  tinygemm(us)   speedup
------------------------------------------
     1      0.0057        0.0034     1.66x
     2      0.0059        0.0034     1.72x
     4      0.0059        0.0034     1.72x
     8      0.0059        0.0034     1.72x
    16      0.0061        0.0034     1.78x
    32      0.0061        0.0034     1.78x
    64      0.0059        0.0036     1.62x
   128      0.0061        0.0036     1.68x
   256      0.0061        0.0059     1.03x
   512      0.0061        0.0104     0.59x

GPU: NVIDIA H100 PCIe

gpt-oss-120b: hidden_size=2880, num_experts=128, bias=True
 batch  F.linear(us)  tinygemm(us)   speedup
------------------------------------------
     1      0.0066        0.0037     1.79x
     2      0.0066        0.0037     1.80x
     4      0.0066        0.0037     1.79x
     8      0.0066        0.0037     1.80x
    16      0.0069        0.0038     1.84x
    32      0.0070        0.0038     1.85x
    64      0.0071        0.0041     1.74x
   128      0.0073        0.0067     1.08x
   256      0.0074        0.0104     0.71x
   512      0.0073        0.0165     0.44x

E2E accuracy

PR:

[{'eval_name': 'gpqa', 'model_name': 'gpt-oss-120b-high_temp1.0_20260316_202021', 'metric': 0.7954545454545454}]

main:

[{'eval_name': 'gpqa', 'model_name': 'gpt-oss-120b-high_temp1.0_20260315_210654', 'metric': 0.7891414141414141}]

E2E perf

PR: about 2% perf gain

============ Serving Benchmark Result ============
Successful requests:                     80
Failed requests:                         0
Maximum request concurrency:             8
Benchmark duration (s):                  29.79
Total input tokens:                      81920
Total generated tokens:                  81920
Request throughput (req/s):              2.69
Output token throughput (tok/s):         2749.91
Peak output token throughput (tok/s):    153.00
Peak concurrent requests:                16.00
Total token throughput (tok/s):          5499.81
---------------Time to First Token----------------
Mean TTFT (ms):                          57.26
Median TTFT (ms):                        65.95
P99 TTFT (ms):                           70.88
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          2.85
Median TPOT (ms):                        2.85
P99 TPOT (ms):                           2.90
---------------Inter-token Latency----------------
Mean ITL (ms):                           56.15
Median ITL (ms):                         56.96
P99 ITL (ms):                            58.44
==================================================

main:

============ Serving Benchmark Result ============
Successful requests:                     80
Failed requests:                         0
Maximum request concurrency:             8
Benchmark duration (s):                  30.68
Total input tokens:                      81920
Total generated tokens:                  81920
Request throughput (req/s):              2.61
Output token throughput (tok/s):         2670.22
Peak output token throughput (tok/s):    144.00
Peak concurrent requests:                16.00
Total token throughput (tok/s):          5340.45
---------------Time to First Token----------------
Mean TTFT (ms):                          74.49
Median TTFT (ms):                        81.17
P99 TTFT (ms):                           102.37
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          2.93
Median TPOT (ms):                        2.92
P99 TPOT (ms):                           2.98
---------------Inter-token Latency----------------
Mean ITL (ms):                           57.55
Median ITL (ms):                         58.38
P99 ITL (ms):                            59.59
==================================================

Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces support for the Flashinfer tinygemm_bf16 kernel for the MoE router GEMM in GPT-OSS models. This is achieved by creating a new GateLinear layer with a four-tier dispatch mechanism, where the new Flashinfer kernel is the third tier. The changes are well-implemented and include performance benchmarks showing a ~2% gain. I've identified a minor correctness issue regarding the type hint for the optional bias parameter in the new custom op, which could lead to runtime errors if GateLinear is used without a bias. My suggestions address this.

@elvischenv
Copy link
Copy Markdown
Contributor Author

cc @robertgshaw2-redhat for viz

elvischenv and others added 3 commits March 18, 2026 01:59
Signed-off-by: elvischenv <219235043+elvischenv@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Signed-off-by: elvischenv <219235043+elvischenv@users.noreply.github.com>
Signed-off-by: elvischenv <219235043+elvischenv@users.noreply.github.com>
@elvischenv elvischenv force-pushed the elvischenv/support-flashinfer-tinygemm branch from b27e13c to c4da2c7 Compare March 18, 2026 09:52
@elvischenv
Copy link
Copy Markdown
Contributor Author

@xyang16 I appreciate your review on my PR, and have picked some of your insights, e.g. benchmarked the kernel perf(updated in the PR description) and added batch size limitation.

if (
self.allow_flashinfer_tinygemm_router_gemm
and x.dtype == torch.bfloat16
and x.shape[0] <= 128
Copy link
Copy Markdown
Contributor

@xyang16 xyang16 Mar 18, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

x.shape[0] <= 128 check needs to be put inside the custom op. Otherwise tinygemm will never be launched. Because torch.compile integration does not support runtime dispatching on num_tokens.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it was called correctly from my last test, and got improved perf.
The existed Tier 1 branch also uses this way.

# Tier 1: DSV3 specialized kernel
if self.allow_dsv3_router_gemm and x.shape[0] <= 16:
output = ops.dsv3_router_gemm(
hidden_states=x,
router_weight=self.weight,
output_dtype=self.out_dtype,
)
return output, None

Copy link
Copy Markdown
Contributor

@xyang16 xyang16 Mar 18, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I profiled your PR with gpt-oss-20b on H200. I don't see tinygemm kernel launched.

If I put the check inside the custom op, I can see tinygemm kernel launched:

void tinygemm_kernel<16, 16, 8, 64, 16, 4, false>(__...         0.00%       0.000us         0.00%       0.000us       0.000us     393.088us         1.51%     393.088us       3.276us           120  

Could you please double check? Thanks!

@mergify
Copy link
Copy Markdown
Contributor

mergify bot commented Mar 18, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @elvischenv.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Mar 18, 2026
@nvpohanh
Copy link
Copy Markdown
Contributor

@elvischenv could you rebase and fix conflicts? thanks

@nvpohanh
Copy link
Copy Markdown
Contributor

Per offline discussion, we think this has been covered by #37205 and we can close this.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

gpt-oss Related to GPT-OSS models needs-rebase

Projects

Status: To Triage

Development

Successfully merging this pull request may close these issues.

3 participants